{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Question Answering Using Milvus and Hugging Face\n",
     "In this notebook, we go over how to search for the best answers to questions using Milvus as the vector database and Hugging Face as the embedding system."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Packages\n",
     "We begin by importing the required packages. In this example, the only non-built-in packages are `datasets`, `transformers`, and `pymilvus`. `transformers` and `datasets` are the Hugging Face packages used to build the embedding pipeline, and `pymilvus` is the client for Milvus. If not present on your system, these packages can be installed using `pip install transformers datasets pymilvus`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/filiphaltmayer/miniconda3/envs/openai/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "2023-02-10 15:59:27.832257: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA\n",
      "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
     ]
    }
   ],
   "source": [
    "from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility\n",
    "from datasets import load_dataset_builder, load_dataset, Dataset\n",
    "from transformers import AutoTokenizer, AutoModel"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Parameters\n",
     "Here are the main parameters that you may need to modify for your own setup. Each is described in the comment beside it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
     "DATASET = 'squad'  # Hugging Face dataset to use\n",
     "MODEL = 'bert-base-uncased'  # Transformer model to use for embeddings\n",
     "TOKENIZATION_BATCH_SIZE = 1000  # Batch size for the tokenization operation\n",
     "INFERENCE_BATCH_SIZE = 64  # Batch size for transformer inference\n",
     "INSERT_RATIO = .001  # Ratio of the dataset to embed and insert\n",
     "COLLECTION_NAME = 'huggingface_db'  # Collection name\n",
     "DIMENSION = 768  # Embedding dimension\n",
     "LIMIT = 10  # How many results to return per search\n",
    "HOST = 'localhost'  # IP for Milvus\n",
    "PORT = 19530"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Milvus\n",
     "This section covers setting up Milvus for this use case. Within Milvus, we need to create a collection and build an index on it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Connect to Milvus Database\n",
    "connections.connect(host=HOST, port=PORT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remove collection if it already exists\n",
    "if utility.has_collection(COLLECTION_NAME):\n",
    "    utility.drop_collection(COLLECTION_NAME)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
     "# Create the collection, which includes the id, original question, answer, and question embedding.\n",
    "fields = [\n",
    "    FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True),\n",
    "    FieldSchema(name='original_question', dtype=DataType.VARCHAR, max_length=1000),\n",
    "    FieldSchema(name='answer', dtype=DataType.VARCHAR, max_length=1000),\n",
    "    FieldSchema(name='original_question_embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION)\n",
    "]\n",
    "schema = CollectionSchema(fields=fields)\n",
    "collection = Collection(name=COLLECTION_NAME, schema=schema)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
     "# Create an IVF_FLAT index on the embedding field and load the collection.\n",
    "index_params = {\n",
    "    'metric_type':'L2',\n",
    "    'index_type':\"IVF_FLAT\",\n",
    "    'params':{\"nlist\":1536}\n",
    "}\n",
    "collection.create_index(field_name=\"original_question_embedding\", index_params=index_params)\n",
    "collection.load()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Insert Data\n",
     "Once the collection is set up, we need to start inserting our data. This is done in three steps: tokenizing the original question, embedding the tokenized question, and inserting the embedding, original question, and answer."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found cached dataset squad (/Users/filiphaltmayer/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453)\n",
      "100%|██████████| 99/99 [00:00<00:00, 1655.16ex/s]\n"
     ]
    }
   ],
   "source": [
    "data_dataset = load_dataset(DATASET, split='all')\n",
    "data_dataset = data_dataset.train_test_split(test_size=INSERT_RATIO)['test']\n",
    "# Clean up the data structure in the dataset.\n",
    "data_dataset = data_dataset.map(lambda val: {'answer': val['answers']['text'][0]}, remove_columns=['answers'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:01<00:00,  1.28s/ba]\n"
     ]
    }
   ],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(MODEL)\n",
    "\n",
     "# Tokenize the question into the format that BERT expects.\n",
    "def tokenize_question(batch):\n",
    "    results = tokenizer(batch['question'], add_special_tokens = True, truncation = True, padding = \"max_length\", return_attention_mask = True, return_tensors = \"pt\")\n",
    "    batch['input_ids'] = results['input_ids']\n",
    "    batch['token_type_ids'] = results['token_type_ids']\n",
    "    batch['attention_mask'] = results['attention_mask']\n",
    "    return batch\n",
    "\n",
    "# Generate the tokens for each entry.\n",
    "data_dataset = data_dataset.map(tokenize_question, batch_size=TOKENIZATION_BATCH_SIZE, batched=True)\n",
     "# Set the output format to torch so it can be fed into the embedding model\n",
    "data_dataset.set_format('torch', columns=['input_ids', 'token_type_ids', 'attention_mask'], output_all_columns=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "100%|██████████| 2/2 [04:21<00:00, 130.94s/ba]\n"
     ]
    }
   ],
   "source": [
    "model = AutoModel.from_pretrained(MODEL)\n",
     "# Embed the tokenized question and use the [CLS] token embedding as the sentence embedding.\n",
    "def embed(batch):\n",
    "    batch['question_embedding'] = model(\n",
    "                input_ids = batch['input_ids'],\n",
    "                token_type_ids=batch['token_type_ids'],\n",
    "                attention_mask = batch['attention_mask']\n",
    "                )[0][:,0,:]\n",
    "    return batch\n",
    "\n",
    "data_dataset = data_dataset.map(embed, remove_columns=['input_ids', 'token_type_ids', 'attention_mask'], batched = True, batch_size=INFERENCE_BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Parameter 'function'=<function insert_function at 0x7fb6b19fa280> of the transform datasets.arrow_dataset.Dataset._map_single couldn't be hashed properly, a random hash was used instead. Make sure your transforms and parameters are serializable with pickle or dill for the dataset fingerprinting and caching to work. If you reuse this transform, the caching mechanism will consider it to be different from the previous calls and recompute everything. This warning is only showed once. Subsequent hashing failures won't be showed.\n",
      "100%|██████████| 2/2 [00:00<00:00,  8.11ba/s]\n"
     ]
    }
   ],
   "source": [
    "def insert_function(batch):\n",
    "    insertable = [\n",
    "        batch['question'],\n",
     "        [x[:995] + '...' if len(x) > 999 else x for x in batch['answer']],  # Truncate long answers to fit the 1000-char VARCHAR limit\n",
    "        batch['question_embedding'].tolist()\n",
    "    ]    \n",
    "    collection.insert(insertable)\n",
    "\n",
    "data_dataset.map(insert_function, batched=True, batch_size=64)\n",
    "collection.flush()"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
     "## Ask Questions\n",
    "Once all the data is inserted and indexed within Milvus, we can ask questions and see what the closest answers are."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1/1 [00:00<00:00, 30.11ba/s]\n",
      "100%|██████████| 1/1 [00:03<00:00,  3.89s/ba]\n"
     ]
    }
   ],
   "source": [
    "questions = {'question':['When was chemistry invented?', 'When was Eisenhower born?']}\n",
    "question_dataset = Dataset.from_dict(questions)\n",
    "\n",
    "question_dataset = question_dataset.map(tokenize_question, batched = True, batch_size=TOKENIZATION_BATCH_SIZE)\n",
    "question_dataset.set_format('torch', columns=['input_ids', 'token_type_ids', 'attention_mask'], output_all_columns=True)\n",
    "question_dataset = question_dataset.map(embed, remove_columns=['input_ids', 'token_type_ids', 'attention_mask'], batched = True, batch_size=INFERENCE_BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 2/2 [00:00<00:00, 49.78ba/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Question:\n",
      "When was chemistry invented?\n",
      "Answer, Distance, Original Question\n",
      "('the sixteenth through the eighteenth centuries', tensor(6.7038), 'When did modern chemistry come into existence?')\n",
      "('1911', tensor(19.4356), 'In what year was phenobarbital discovered?')\n",
      "(\"the Architects' Registration Council of the United Kingdom\", tensor(20.2873), 'What organization was the Royal Institute instrumental in establishing?')\n",
      "('surprise attack to regain part of the Sinai territory Israel had captured 6 years earlier', tensor(20.4047), 'What was the October War?')\n",
      "('Internet service providers', tensor(20.4992), 'Who did early court cases focus on?')\n",
      "('19th', tensor(21.2683), 'In what century was Eisenhower born?')\n",
      "('meaning \"all port\" due to the shape of its coast.', tensor(21.4290), 'Why did the Greeks name Palermo Panormos?')\n",
      "('the Korean border', tensor(21.8102), 'Where did the Chinese military deploy troops in preparation for the arrival of US troops?')\n",
      "('the printing press', tensor(22.5804), 'What spread the use of texts by Galen within universities?')\n",
      "('talking monetary policy with the market', tensor(23.7503), 'What does \"open mouth operations\" mean?')\n",
      "\n",
      "Question:\n",
      "When was Eisenhower born?\n",
      "Answer, Distance, Original Question\n",
      "('19th', tensor(8.2841), 'In what century was Eisenhower born?')\n",
      "('the sixteenth through the eighteenth centuries', tensor(20.3792), 'When did modern chemistry come into existence?')\n",
      "('1790', tensor(21.3025), 'In what year was the Brown Fellowship Society created?')\n",
      "('1911', tensor(22.0691), 'In what year was phenobarbital discovered?')\n",
      "('surprise attack to regain part of the Sinai territory Israel had captured 6 years earlier', tensor(22.7447), 'What was the October War?')\n",
      "('1997', tensor(25.5775), 'What year was the second generation of digimon released?')\n",
      "('John Williams', tensor(27.5747), 'What composer has Steven Spielberg been associated with since 1974?')\n",
      "(\"the Architects' Registration Council of the United Kingdom\", tensor(28.0410), 'What organization was the Royal Institute instrumental in establishing?')\n",
      "('talking monetary policy with the market', tensor(28.2592), 'What does \"open mouth operations\" mean?')\n",
      "('the Korean border', tensor(28.8382), 'Where did the Chinese military deploy troops in preparation for the arrival of US troops?')\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "search_params = {\n",
    "    'nprobe': 128\n",
    "}\n",
    "\n",
    "def search(batch):\n",
    "    res = collection.search(batch['question_embedding'].tolist(), anns_field='original_question_embedding', param = search_params, output_fields=['answer', 'original_question'], limit = LIMIT)\n",
    "    overall_id = []\n",
    "    overall_distance = []\n",
    "    overall_answer  = []\n",
    "    overall_original_question = []\n",
    "    for hits in res:\n",
    "        ids = []\n",
    "        distance = []\n",
    "        answer = []\n",
    "        original_question = []\n",
    "        for hit in hits:\n",
    "            ids.append(hit.id)\n",
    "            distance.append(hit.distance)\n",
    "            answer.append(hit.entity.get('answer'))\n",
    "            original_question.append(hit.entity.get('original_question'))\n",
    "        overall_id.append(ids)\n",
    "        overall_distance.append(distance)\n",
    "        overall_answer.append(answer)\n",
    "        overall_original_question.append(original_question)\n",
    "    return {\n",
    "        'id': overall_id,\n",
    "        'distance': overall_distance,\n",
    "        'answer': overall_answer,\n",
    "        'original_question': overall_original_question\n",
    "    }\n",
    "question_dataset = question_dataset.map(search, batched=True, batch_size = 1)\n",
    "for x in question_dataset:\n",
    "    print()\n",
    "    print('Question:')\n",
    "    print(x['question'])\n",
    "    print('Answer, Distance, Original Question')\n",
     "    for result in zip(x['answer'], x['distance'], x['original_question']):\n",
     "        print(result)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "openai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.15"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "20e2a1b77e7395ec3f747af99b3084257dd9a83ab453e4f5fc77b9434eecfeb0"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
